#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os,sys,string,re
import logging
from optparse import OptionParser
import tushare as ts
import numpy as np
import pandas as pd
from sklearn import cluster, covariance, manifold
import matplotlib.pyplot as plt
import matplotlib.font_manager as mfm
from matplotlib.collections import LineCollection

# Tushare pro API client used by pro_request_data(); presumably it relies on a
# token previously stored with ts.set_token() -- confirm for new machines.
pro = ts.pro_api()
# Font handed to plt.text() in the __main__ section when drawing stock names.
# NOTE(review): macOS-specific path; presumably a CJK font is needed because
# the names from hs300-list.txt are Chinese -- adjust on other platforms.
font_path = "/Library/Fonts/Songti.ttc"
prop = mfm.FontProperties(fname=font_path)


# Log file written by log(). The encoding is explicit so that non-ASCII
# messages are written identically regardless of the platform locale.
logfile = open("spider-running.log", "w", encoding="utf-8")
def log(str):
    """Echo a message to stdout and append it as one line to ``logfile``.

    The parameter keeps its original name ``str`` so keyword callers keep
    working, even though it shadows the builtin inside this function.

    Args:
        str: the message; any object is accepted and formatted with ``%s``.
    """
    print(str)
    # '%s' formatting instead of '+' so non-string messages (numbers, tuples,
    # shapes) are logged rather than raising TypeError; the builtin str() is
    # not reachable here because of the shadowing parameter.
    logfile.write('%s\n' % (str,))
    # Flush after every message so the log stays complete even if a later
    # step (network fetch, plotting) crashes or the process is killed.
    logfile.flush()

def gen_TsCode(code):
    """Return ``code`` with the exchange suffix used by tushare pro ts_codes.

    Only the first character is inspected: a leading digit of 6 or more maps
    to ``.SH``, anything lower maps to ``.SZ``.

    NOTE(review): this is adequate for the 0/3/6-prefixed equity codes of the
    HS300 list; confirm before using it for other instrument types.

    Raises:
        IndexError: ``code`` is empty.
        ValueError: the first character is not a digit.
    """
    exchange = ".SH" if int(code[0]) >= 6 else ".SZ"
    return code + exchange

def request_data(code='000001', start='20180102', end='20190320'):
    """Fetch daily bars with tushare's legacy API and save ``data/<code>.csv``.

    Unlike pro_request_data this always re-downloads and overwrites the file.

    Args:
        code: bare stock code such as ``'000001'`` (no exchange suffix).
        start, end: date bounds. The legacy ``get_hist_data`` expects
            ``YYYY-MM-DD``; 8-digit ``YYYYMMDD`` strings (the form used by the
            defaults and by pro_request_data) are converted automatically.
            Any other value is passed through unchanged.
    """
    def _dashed(day):
        # 'YYYYMMDD' -> 'YYYY-MM-DD'; leave anything else (None, already
        # dashed dates) untouched.
        if isinstance(day, str) and len(day) == 8 and day.isdigit():
            return '%s-%s-%s' % (day[:4], day[4:6], day[6:])
        return day

    # to_csv does not create missing directories.
    os.makedirs('data', exist_ok=True)
    hist = ts.get_hist_data(code=code, start=_dashed(start), end=_dashed(end))
    hist.to_csv('data/%s.csv' % code)

def pro_request_data(code='000001', start='20180102', end='20190320'):
    """Download daily bars for ``code`` via tushare pro and cache them as CSV.

    The data is written to ``data/<code>.csv``. If that file already exists
    nothing is fetched, so delete the file to force a refresh.

    Args:
        code: bare stock code such as ``'000001'``; the exchange suffix is
            added by gen_TsCode().
        start, end: ``YYYYMMDD`` strings passed through as pro.daily's
            ``start_date`` / ``end_date``.
    """
    fl = 'data/%s.csv' % code
    if not os.path.isfile(fl):
        db = pro.daily(ts_code=gen_TsCode(code), start_date=start, end_date=end)
        # Intended to fill data for trading-suspension periods: NaN cells
        # take the value of the preceding row.
        # NOTE(review): ffill only fills NaNs inside returned rows; it does
        # not add rows for dates that are absent. Also, if pro.daily returns
        # rows newest-first, the "preceding row" is a later date -- confirm
        # this is the intended fill direction.
        db.ffill(axis=0, inplace=True)
        # to_csv does not create missing directories.
        os.makedirs(os.path.dirname(fl), exist_ok=True)
        db.to_csv(fl)
        log("code:%s db-shape:%s" % (code, db.shape))

def spider_data():
    """Ensure daily data is cached for every stock in ``hs300-list.txt``.

    Each non-blank line of the list is read as comma-separated fields with
    the stock code first and its name second; any further fields are
    ignored. pro_request_data() is called for every code, which downloads the
    data unless ``data/<code>.csv`` already exists.

    Returns:
        dict mapping stock code -> stock name.
    """
    stockDict = {}
    # Read the whole list up front so the file is closed before the
    # (potentially slow) network fetches start.
    # NOTE(review): assumes the list is UTF-8 encoded -- confirm.
    with open("hs300-list.txt", encoding="utf-8") as listing:
        allData = listing.readlines()

    for line in allData:
        line = line.strip()
        if not line:
            # Blank lines (e.g. a trailing newline at EOF) have no name field
            # and would otherwise make ss[1] raise IndexError.
            continue
        ss = [field.strip() for field in line.split(',')]
        pro_request_data(ss[0])
        stockDict[ss[0]] = ss[1]

    return stockDict


if __name__ == '__main__':
    # Make sure every listed stock has cached data and get code -> name.
    symbol_dict = spider_data()
    # Sorting by code gives a fixed row order; `symbols` and `names` are
    # parallel arrays indexed like the rows of the matrices below.
    symbols, names = np.array(sorted(symbol_dict.items())).T

    quotes = []
    fields = ['trade_date', 'open', 'close']
    url = './data/{}.csv'

    for symbol in symbols:
        print('Fetching quote history for %r' % symbol, file=sys.stderr)
        # Only the first 90 rows of each cached CSV are used.
        db = pd.read_csv(url.format(symbol), skipinitialspace=True,
                         usecols=fields, nrows=90)
        quotes.append(db)

    # Some series have fewer than 90 rows (short history, missing sessions);
    # np.vstack needs equal lengths, so truncate all of them to the shortest.
    # NOTE(review): rows are matched by position, not by trade_date, so a
    # stock with missing sessions is misaligned with the others -- consider
    # joining on trade_date instead.
    min_x = min(q.shape[0] for q in quotes)
    print(min_x)
    close_prices = np.vstack([q['close'].values[:min_x] for q in quotes])
    open_prices = np.vstack([q['open'].values[:min_x] for q in quotes])

    # The daily variations of the quotes are what carry most information
    variation = close_prices - open_prices


    # #############################################################################
    # Learn a graphical structure from the correlations
    edge_model = covariance.GraphicalLassoCV(cv=5)

    # standardize the time series: using correlations rather than covariance
    # is more efficient for structure recovery
    # NOTE(review): a stock whose variation has zero standard deviation over
    # the window becomes NaN here and will make the fit fail.
    X = variation.copy().T
    X /= X.std(axis=0)
    edge_model.fit(X)


    # #############################################################################
    # Cluster using affinity propagation

    _, labels = cluster.affinity_propagation(edge_model.covariance_)
    n_labels = labels.max()

    for i in range(n_labels + 1):
        print('Cluster %i: %s' % ((i + 1), ', '.join(names[labels == i])))


    # #############################################################################
    # Find a low-dimension embedding for visualization: find the best position of
    # the nodes (the stocks) on a 2D plane

    # We use a dense eigen_solver to achieve reproducibility (arpack is
    # initiated with random vectors that we don't control). The neighbour
    # count controls how much large-scale structure the embedding captures.
    node_position_model = manifold.LocallyLinearEmbedding(
        n_components=2, eigen_solver='dense', n_neighbors=6)

    embedding = node_position_model.fit_transform(X.T).T

    # #############################################################################
    # Visualization
    plt.figure(1, facecolor='w', figsize=(10, 8))
    plt.clf()
    ax = plt.axes([0., 0., 1., 1.])
    plt.axis('off')

    # Display a graph of the partial correlations: rescale the precision
    # matrix by 1/sqrt of its diagonal on both sides.
    partial_correlations = edge_model.precision_.copy()
    d = 1 / np.sqrt(np.diag(partial_correlations))
    partial_correlations *= d
    partial_correlations *= d[:, np.newaxis]
    non_zero = (np.abs(np.triu(partial_correlations, k=1)) > 0.02)

    # Plot the nodes using the coordinates of our embedding
    plt.scatter(embedding[0], embedding[1], s=100 * d ** 2, c=labels,
                cmap=plt.cm.nipy_spectral)

    # Plot the edges
    start_idx, end_idx = np.where(non_zero)
    # a sequence of (*line0*, *line1*, *line2*), where::
    #            linen = (x0, y0), (x1, y1), ... (xm, ym)
    segments = [[embedding[:, start], embedding[:, stop]]
                for start, stop in zip(start_idx, end_idx)]
    values = np.abs(partial_correlations[non_zero])
    # values.max() raises on an empty array, so only draw edges when at
    # least one partial correlation passed the threshold.
    if values.size:
        lc = LineCollection(segments,
                            zorder=0, cmap=plt.cm.hot_r,
                            norm=plt.Normalize(0, .7 * values.max()))
        lc.set_array(values)
        lc.set_linewidths(15 * values)
        ax.add_collection(lc)

    # Divisor that maps cluster labels into [0, 1] for the colormap; with a
    # single cluster labels.max() is 0 and would cause a ZeroDivisionError.
    color_scale = float(max(n_labels, 1))

    # Add a label to each node. The challenge here is that we want to
    # position the labels to avoid overlap with other labels
    for index, (name, label, (x, y)) in enumerate(
            zip(names, labels, embedding.T)):

        # Offsets to every other node; this node's own entry is forced to 1.
        dx = x - embedding[0]
        dx[index] = 1
        dy = y - embedding[1]
        dy[index] = 1
        # x-offset from the node closest in y, and y-offset from the node
        # closest in x: the text is pushed away from those neighbours.
        this_dx = dx[np.argmin(np.abs(dy))]
        this_dy = dy[np.argmin(np.abs(dx))]
        if this_dx > 0:
            horizontalalignment = 'left'
            x = x + .002
        else:
            horizontalalignment = 'right'
            x = x - .002
        if this_dy > 0:
            verticalalignment = 'bottom'
            y = y + .002
        else:
            verticalalignment = 'top'
            y = y - .002
        plt.text(x, y, name, size=10,
                 horizontalalignment=horizontalalignment,
                 verticalalignment=verticalalignment,
                 fontproperties=prop,
                 bbox=dict(facecolor='w',
                           edgecolor=plt.cm.nipy_spectral(label / color_scale),
                           alpha=.6))

    # ndarray.ptp() was removed in NumPy 2.0; np.ptp works on all versions.
    x_span = np.ptp(embedding[0])
    y_span = np.ptp(embedding[1])
    plt.xlim(embedding[0].min() - .15 * x_span,
             embedding[0].max() + .10 * x_span)
    plt.ylim(embedding[1].min() - .03 * y_span,
             embedding[1].max() + .03 * y_span)

    plt.show()


